1 package org.apache.lucene.search;
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20 import org.apache.lucene.util.PriorityQueue;
21
22 import java.io.IOException;
23
24
25
26 public class TopDocs {
27
28
29 public int totalHits;
30
31
32 public ScoreDoc[] scoreDocs;
33
34
35 private float maxScore;
36
37
38
39
40
41 public float getMaxScore() {
42 return maxScore;
43 }
44
45
46 public void setMaxScore(float maxScore) {
47 this.maxScore = maxScore;
48 }
49
50
51 TopDocs(int totalHits, ScoreDoc[] scoreDocs) {
52 this(totalHits, scoreDocs, Float.NaN);
53 }
54
55 public TopDocs(int totalHits, ScoreDoc[] scoreDocs, float maxScore) {
56 this.totalHits = totalHits;
57 this.scoreDocs = scoreDocs;
58 this.maxScore = maxScore;
59 }
60
61
62 private static class ShardRef {
63
64 final int shardIndex;
65
66
67 int hitIndex;
68
69 public ShardRef(int shardIndex) {
70 this.shardIndex = shardIndex;
71 }
72
73 @Override
74 public String toString() {
75 return "ShardRef(shardIndex=" + shardIndex + " hitIndex=" + hitIndex + ")";
76 }
77 };
78
79
80
81 private static class ScoreMergeSortQueue extends PriorityQueue<ShardRef> {
82 final ScoreDoc[][] shardHits;
83
84 public ScoreMergeSortQueue(TopDocs[] shardHits) {
85 super(shardHits.length);
86 this.shardHits = new ScoreDoc[shardHits.length][];
87 for(int shardIDX=0;shardIDX<shardHits.length;shardIDX++) {
88 this.shardHits[shardIDX] = shardHits[shardIDX].scoreDocs;
89 }
90 }
91
92
93 @Override
94 public boolean lessThan(ShardRef first, ShardRef second) {
95 assert first != second;
96 final float firstScore = shardHits[first.shardIndex][first.hitIndex].score;
97 final float secondScore = shardHits[second.shardIndex][second.hitIndex].score;
98
99 if (firstScore < secondScore) {
100 return false;
101 } else if (firstScore > secondScore) {
102 return true;
103 } else {
104
105 if (first.shardIndex < second.shardIndex) {
106 return true;
107 } else if (first.shardIndex > second.shardIndex) {
108 return false;
109 } else {
110
111
112 assert first.hitIndex != second.hitIndex;
113 return first.hitIndex < second.hitIndex;
114 }
115 }
116 }
117 }
118
119 @SuppressWarnings({"rawtypes","unchecked"})
120 private static class MergeSortQueue extends PriorityQueue<ShardRef> {
121
122 final ScoreDoc[][] shardHits;
123 final FieldComparator<?>[] comparators;
124 final int[] reverseMul;
125
126 public MergeSortQueue(Sort sort, TopDocs[] shardHits) throws IOException {
127 super(shardHits.length);
128 this.shardHits = new ScoreDoc[shardHits.length][];
129 for(int shardIDX=0;shardIDX<shardHits.length;shardIDX++) {
130 final ScoreDoc[] shard = shardHits[shardIDX].scoreDocs;
131
132 if (shard != null) {
133 this.shardHits[shardIDX] = shard;
134
135 for(int hitIDX=0;hitIDX<shard.length;hitIDX++) {
136 final ScoreDoc sd = shard[hitIDX];
137 if (!(sd instanceof FieldDoc)) {
138 throw new IllegalArgumentException("shard " + shardIDX + " was not sorted by the provided Sort (expected FieldDoc but got ScoreDoc)");
139 }
140 final FieldDoc fd = (FieldDoc) sd;
141 if (fd.fields == null) {
142 throw new IllegalArgumentException("shard " + shardIDX + " did not set sort field values (FieldDoc.fields is null); you must pass fillFields=true to IndexSearcher.search on each shard");
143 }
144 }
145 }
146 }
147
148 final SortField[] sortFields = sort.getSort();
149 comparators = new FieldComparator[sortFields.length];
150 reverseMul = new int[sortFields.length];
151 for(int compIDX=0;compIDX<sortFields.length;compIDX++) {
152 final SortField sortField = sortFields[compIDX];
153 comparators[compIDX] = sortField.getComparator(1, compIDX);
154 reverseMul[compIDX] = sortField.getReverse() ? -1 : 1;
155 }
156 }
157
158
159 @Override
160 public boolean lessThan(ShardRef first, ShardRef second) {
161 assert first != second;
162 final FieldDoc firstFD = (FieldDoc) shardHits[first.shardIndex][first.hitIndex];
163 final FieldDoc secondFD = (FieldDoc) shardHits[second.shardIndex][second.hitIndex];
164
165
166 for(int compIDX=0;compIDX<comparators.length;compIDX++) {
167 final FieldComparator comp = comparators[compIDX];
168
169
170 final int cmp = reverseMul[compIDX] * comp.compareValues(firstFD.fields[compIDX], secondFD.fields[compIDX]);
171
172 if (cmp != 0) {
173
174 return cmp < 0;
175 }
176 }
177
178
179 if (first.shardIndex < second.shardIndex) {
180
181 return true;
182 } else if (first.shardIndex > second.shardIndex) {
183
184 return false;
185 } else {
186
187
188
189 assert first.hitIndex != second.hitIndex;
190 return first.hitIndex < second.hitIndex;
191 }
192 }
193 }
194
195
196
197
198
199 public static TopDocs merge(int topN, TopDocs[] shardHits) throws IOException {
200 return merge(0, topN, shardHits);
201 }
202
203
204
205
206
207
208 public static TopDocs merge(int start, int topN, TopDocs[] shardHits) throws IOException {
209 return mergeAux(null, start, topN, shardHits);
210 }
211
212
213
214
215
216
217
218
219 public static TopFieldDocs merge(Sort sort, int topN, TopFieldDocs[] shardHits) throws IOException {
220 return merge(sort, 0, topN, shardHits);
221 }
222
223
224
225
226
227
228 public static TopFieldDocs merge(Sort sort, int start, int topN, TopFieldDocs[] shardHits) throws IOException {
229 if (sort == null) {
230 throw new IllegalArgumentException("sort must be non-null when merging field-docs");
231 }
232 return (TopFieldDocs) mergeAux(sort, start, topN, shardHits);
233 }
234
235
236
237 private static TopDocs mergeAux(Sort sort, int start, int size, TopDocs[] shardHits) throws IOException {
238 final PriorityQueue<ShardRef> queue;
239 if (sort == null) {
240 queue = new ScoreMergeSortQueue(shardHits);
241 } else {
242 queue = new MergeSortQueue(sort, shardHits);
243 }
244
245 int totalHitCount = 0;
246 int availHitCount = 0;
247 float maxScore = Float.MIN_VALUE;
248 for(int shardIDX=0;shardIDX<shardHits.length;shardIDX++) {
249 final TopDocs shard = shardHits[shardIDX];
250
251
252 totalHitCount += shard.totalHits;
253 if (shard.scoreDocs != null && shard.scoreDocs.length > 0) {
254 availHitCount += shard.scoreDocs.length;
255 queue.add(new ShardRef(shardIDX));
256 maxScore = Math.max(maxScore, shard.getMaxScore());
257
258 }
259 }
260
261 if (availHitCount == 0) {
262 maxScore = Float.NaN;
263 }
264
265 final ScoreDoc[] hits;
266 if (availHitCount <= start) {
267 hits = new ScoreDoc[0];
268 } else {
269 hits = new ScoreDoc[Math.min(size, availHitCount - start)];
270 int requestedResultWindow = start + size;
271 int numIterOnHits = Math.min(availHitCount, requestedResultWindow);
272 int hitUpto = 0;
273 while (hitUpto < numIterOnHits) {
274 assert queue.size() > 0;
275 ShardRef ref = queue.top();
276 final ScoreDoc hit = shardHits[ref.shardIndex].scoreDocs[ref.hitIndex++];
277 hit.shardIndex = ref.shardIndex;
278 if (hitUpto >= start) {
279 hits[hitUpto - start] = hit;
280 }
281
282
283
284
285 hitUpto++;
286
287 if (ref.hitIndex < shardHits[ref.shardIndex].scoreDocs.length) {
288
289 queue.updateTop();
290 } else {
291 queue.pop();
292 }
293 }
294 }
295
296 if (sort == null) {
297 return new TopDocs(totalHitCount, hits, maxScore);
298 } else {
299 return new TopFieldDocs(totalHitCount, hits, sort.getSort(), maxScore);
300 }
301 }
302 }